package jocl.samples;

import static org.jocl.CL.CL_CONTEXT_PLATFORM;
import static org.jocl.CL.CL_DEVICE_NAME;
import static org.jocl.CL.CL_DEVICE_TYPE_ALL;
import static org.jocl.CL.CL_MEM_COPY_HOST_PTR;
import static org.jocl.CL.CL_MEM_READ_ONLY;
import static org.jocl.CL.CL_MEM_READ_WRITE;
import static org.jocl.CL.CL_PLATFORM_NAME;
import static org.jocl.CL.CL_TRUE;
import static org.jocl.CL.clBuildProgram;
import static org.jocl.CL.clCreateBuffer;
import static org.jocl.CL.clCreateCommandQueue;
import static org.jocl.CL.clCreateContext;
import static org.jocl.CL.clCreateKernel;
import static org.jocl.CL.clCreateProgramWithSource;
import static org.jocl.CL.clEnqueueNDRangeKernel;
import static org.jocl.CL.clEnqueueReadBuffer;
import static org.jocl.CL.clGetDeviceIDs;
import static org.jocl.CL.clGetDeviceInfo;
import static org.jocl.CL.clGetPlatformIDs;
import static org.jocl.CL.clGetPlatformInfo;
import static org.jocl.CL.clReleaseCommandQueue;
import static org.jocl.CL.clReleaseContext;
import static org.jocl.CL.clReleaseEvent;
import static org.jocl.CL.clReleaseKernel;
import static org.jocl.CL.clReleaseMemObject;
import static org.jocl.CL.clReleaseProgram;
import static org.jocl.CL.clSetKernelArg;
import static org.jocl.CL.clWaitForEvents;

import java.util.Arrays;

import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_event;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_platform_id;
import org.jocl.cl_program;

/**
 * A small JOCL sample that uses multiple devices. <br />
 * <br />
 * Note: This is just a basic demo, showing the possibility to use multiple 
 * devices simultaneously. Each device receives its own copy of the memory 
 * objects to work on. In real applications, there may be a more complex
 * management of the buffers and the synchronization between the different 
 * devices, which is beyond the scope of this sample.
 */
public class JOCLSampleGPUNvidia {
    /**
     * The source code of the OpenCL program to execute, containing 
     * some artificial workload to compute
     */
    private static String programSource = 
        "__kernel void sampleKernel(__global const float *input,"+
        "                           __global float *output, " +
        "                           int size)"+
        "{"+
        "    int gid = get_global_id(0);"+
        "    output[gid] = 0;" +
        "    for (int i=0; i<size; i++) " +
        "        output[gid] += input[i];" +
        "}";
    

    /**
     * The entry point of this sample
     * 
     * @param args Not used
     */
    public static void main(String args[]) {
        // Create input- and output data 
        int n = 50000;
        float input[] = new float[n];
        float output[] = new float[n];
        Arrays.fill(input, 1.123f);

        // The platform and device type that will be used
        // platform[0] = Intel
        // platform[1] = Nvidia
        final int platformIndex = 1;
        final long deviceType = CL_DEVICE_TYPE_ALL;   
        final int deviceIndex = 0; // GPU Nvidia GT630M

        // Enable exceptions and subsequently omit error checks in this sample
        CL.setExceptionsEnabled(true);

        // Obtain the number of platforms
        int numPlatformsArray[] = new int[1];
        clGetPlatformIDs(0, null, numPlatformsArray);
        int numPlatforms = numPlatformsArray[0];

        // Obtain a platform ID
        cl_platform_id platforms[] = new cl_platform_id[numPlatforms];
        clGetPlatformIDs(platforms.length, platforms, null);
        cl_platform_id platform = platforms[platformIndex];
        String platformName = getString(platform, CL_PLATFORM_NAME);
        System.out.println("Using platform "+platformIndex+" of "+ numPlatforms+": "+platformName);

        // Initialize the context properties
        cl_context_properties contextProperties = new cl_context_properties();
        contextProperties.addProperty(CL_CONTEXT_PLATFORM, platform);
        
        // Obtain the number of devices for the platform
        int numDevicesArray[] = new int[1];
        clGetDeviceIDs(platform, deviceType, 0, null, numDevicesArray);
        int numDevices = numDevicesArray[0];
        
        // Obtain a device IDs 
        cl_device_id devices[] = new cl_device_id[numDevices];
        clGetDeviceIDs(platform, deviceType, numDevices, devices, null);
        String deviceName = getString(devices[deviceIndex], CL_DEVICE_NAME);
        System.out.println("Device :" + deviceName);

        // Create a context for the devices
        cl_context context = clCreateContext(
            contextProperties, 1, new cl_device_id[]{devices[deviceIndex]}, null, null, null);

        // Create and build the program and the kernel
        cl_program program = clCreateProgramWithSource(context,
            1, new String[]{ programSource }, null, null);
        clBuildProgram(program, 0, null, null, null, null);
        cl_kernel kernel = clCreateKernel(program, "sampleKernel", null);

        // Allocate the memory objects for the input- and output data
		cl_mem inputMem = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, 
									  Sizeof.cl_float * n, Pointer.to(input), null);
		
        cl_mem outputMem = clCreateBuffer(context, CL_MEM_READ_WRITE, Sizeof.cl_float * n, null, null);

        // Create command-queue for device
        long properties = 0;
        properties |= CL.CL_QUEUE_PROFILING_ENABLE ;
        cl_command_queue commandQueue = clCreateCommandQueue(context, devices[deviceIndex], properties, null);
        
        // Execute the kernel on each command queue, and 
        // create events for each kernel launch
		clSetKernelArg(kernel, 0, Sizeof.cl_mem, Pointer.to(inputMem));
		clSetKernelArg(kernel, 1, Sizeof.cl_mem, Pointer.to(outputMem));
		clSetKernelArg(kernel, 2, Sizeof.cl_int, Pointer.to(new int[] { n }));
		cl_event event = new cl_event();
		clEnqueueNDRangeKernel(commandQueue, kernel, 1, null, new long[] { n },
				null, 0, null, event);
        
        // waiting for kernel completed
        clWaitForEvents(1, new cl_event[]{event});
        float durationMs = computeDurationMs(event);
        System.out.println("Duration on device "+ durationMs + "ms");
        
        // Read the output data of the first device
        clEnqueueReadBuffer(commandQueue, outputMem, CL_TRUE, 0, n * Sizeof.cl_float, Pointer.to(output), 0, null, null);
        
        // Release kernel, program, and memory objects
        clReleaseKernel(kernel);
        clReleaseProgram(program);
        clReleaseMemObject(inputMem);
        clReleaseMemObject(outputMem);
        clReleaseEvent(event);
        clReleaseCommandQueue(commandQueue);
        clReleaseContext(context);
        
        // Print the first few elements of the result
        for (int i=0; i<10; i++) {
            float x = output[i];
            System.out.print(x+", ");
        }
        System.out.println("...");
        System.out.println("Done");
    }
    
    /**
     * Compute the execution duration of the given event, in milliseconds
     * 
     * @param event The event
     * @return The execution duration, in milliseconds
     */
    private static float computeDurationMs(cl_event event)
    {
        long startTime[] = {0};
        long endTime[] = {0};
        CL.clGetEventProfilingInfo(
            event, CL.CL_PROFILING_COMMAND_START,
            Sizeof.cl_ulong, Pointer.to(startTime), null);
        CL.clGetEventProfilingInfo(
            event, CL.CL_PROFILING_COMMAND_END,
            Sizeof.cl_ulong, Pointer.to(endTime), null);
        long durationNs = endTime[0]-startTime[0];
        return durationNs / 1e6f;
    }
    
    
    /**
     * Returns the value of the platform info parameter with the given name
     *
     * @param platform The platform
     * @param paramName The parameter name
     * @return The value
     */
    private static String getString(cl_platform_id platform, int paramName) {
        long size[] = new long[1];
        clGetPlatformInfo(platform, paramName, 0, null, size);
        byte buffer[] = new byte[(int)size[0]];
        clGetPlatformInfo(platform, paramName, 
            buffer.length, Pointer.to(buffer), null);
        return new String(buffer, 0, buffer.length-1);
    }
    
    /**
     * Returns the value of the device info parameter with the given name
     *
     * @param device The device
     * @param paramName The parameter name
     * @return The value
     */
    private static String getString(cl_device_id device, int paramName) {
        long size[] = new long[1];
        clGetDeviceInfo(device, paramName, 0, null, size);
        byte buffer[] = new byte[(int)size[0]];
        clGetDeviceInfo(device, paramName, 
            buffer.length, Pointer.to(buffer), null);
        return new String(buffer, 0, buffer.length-1);
    }
    
}